{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Log-Linear Language Model with Data Loader\n",
    "\n",
    "Status of Notebook: Work in Progress\n",
    "\n",
    "The difference from `loglin-lm.ipynb` is that this notebook feeds the data to the model through a PyTorch `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import math\n",
    "import time\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# uncomment to download the datasets into data/ptb/, the directory the notebook reads them from\n",
    "#!wget -P data/ptb https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/ptb/test.txt\n",
    "#!wget -P data/ptb https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/ptb/train.txt\n",
    "#!wget -P data/ptb https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/ptb/valid.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Process the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read a tokenised corpus file: one sentence per line, tokens separated by single spaces\n",
    "def read_data(filename):\n",
    "    \"\"\"Return every line of `filename` as a list of tokens.\n",
    "\n",
    "    Each line is stripped of surrounding whitespace and then split on the\n",
    "    single-space character.\n",
    "    \"\"\"\n",
    "    with open(filename, \"r\") as f:\n",
    "        return [line.strip().split(\" \") for line in f]\n",
    "\n",
    "# read the data\n",
    "train_data = read_data('data/ptb/train.txt')\n",
    "val_data = read_data('data/ptb/valid.txt')\n",
    "\n",
    "# word <-> index lookups, seeded with the special tokens <s> (index 0) and <unk> (index 1)\n",
    "word_to_index = {\"<s>\": 0, \"<unk>\": 1}\n",
    "index_to_word = {index: word for word, index in word_to_index.items()}\n",
    "\n",
    "# grow the word <-> index lookups from data\n",
    "def create_dict(data, check_unk=False):\n",
    "    \"\"\"Add the words found in `data` to the module-level vocabulary.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of sentences, each an iterable of word strings.\n",
    "        check_unk: when False, every unseen word is given a fresh index in\n",
    "            both `word_to_index` and `index_to_word`. When True, an unseen\n",
    "            word is only aliased to the index of \"<unk>\" in `word_to_index`.\n",
    "    \"\"\"\n",
    "    unk_index = word_to_index[\"<unk>\"]\n",
    "    for line in data:\n",
    "        for word in line:\n",
    "            if word in word_to_index:\n",
    "                continue\n",
    "            if check_unk:\n",
    "                # alias to <unk>: no index is allocated, so index_to_word stays untouched\n",
    "                # (has no effect on data that already comes with <unk> substituted)\n",
    "                word_to_index[word] = unk_index\n",
    "            else:\n",
    "                # index_to_word holds exactly the allocated indices, so its length is the\n",
    "                # next free one even when word_to_index also contains <unk> aliases\n",
    "                new_index = len(index_to_word)\n",
    "                word_to_index[word] = new_index\n",
    "                index_to_word[new_index] = word\n",
    "\n",
    "create_dict(train_data)\n",
    "create_dict(val_data, check_unk=True)\n",
    "\n",
    "# convert sentences of words into lists of word indices\n",
    "def create_tensor(data):\n",
    "    \"\"\"Yield each sentence of `data` as a list of word indices.\"\"\"\n",
    "    for line in data:\n",
    "        yield [word_to_index[word] for word in line]\n",
    "\n",
    "train_data = [*create_tensor(train_data)]\n",
    "val_data = [*create_tensor(val_data)]\n",
    "\n",
    "# vocabulary size = number of allocated indices (words aliased to <unk> do not add one)\n",
    "number_of_words = len(index_to_word)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert data to PyTorch Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class PTB(Dataset):\n",
    "    \"\"\"Dataset over a list of sentences; item `i` is sentence `i` as a tensor.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        # keep a reference to the sentences; conversion happens lazily per item\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of sentences in the dataset.\"\"\"\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return sentence `idx` converted with `torch.as_tensor`.\"\"\"\n",
    "        sentence = self.data[idx]\n",
    "        return torch.as_tensor(sentence)\n",
    "\n",
    "# wrap the index-encoded sentences in Dataset objects\n",
    "train_dataset, val_dataset = PTB(train_data), PTB(val_data)\n",
    "\n",
    "# one sentence per batch, visited in random order\n",
    "train_loader, val_loader = (\n",
    "    DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "    for dataset in (train_dataset, val_dataset)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
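    "In our implementation we use batched training: all n-gram histories of a sentence are passed through the model in a single forward call, while the data loader still yields one sentence at a time. There are a few differences from the original implementation found [here](https://github.com/neubig/nn4nlp-code/blob/master/02-lm/loglin-lm.py)."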
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## define the model\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# length of the n-gram\n",
    "N = 2\n",
    "\n",
    "# log-linear model\n",
    "class LogLinear(nn.Module):\n",
    "    \"\"\"Log-linear n-gram language model.\n",
    "\n",
    "    The score of every candidate next word is the sum of one learned vector\n",
    "    per history position plus a learned bias.\n",
    "\n",
    "    Args:\n",
    "        number_of_words: vocabulary size; also the size of each score vector.\n",
    "        ngram_length: number of history words the model conditions on.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, number_of_words, ngram_length):\n",
    "        super(LogLinear, self).__init__()\n",
    "\n",
    "        # different lookups for each position in the n-gram\n",
    "        self.embeddings = nn.ModuleList([nn.Embedding(number_of_words, number_of_words) for _ in range(ngram_length)])\n",
    "        # registered as a Parameter so that model.parameters() hands it to the optimizer\n",
    "        # and model.to(...) moves it together with the rest of the module\n",
    "        self.bias = nn.Parameter(torch.zeros(number_of_words))\n",
    "\n",
    "        # initialize every lookup table (one per history position, not the global N)\n",
    "        for lookup in self.embeddings:\n",
    "            nn.init.xavier_uniform_(lookup.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Score the next word for a batch of histories.\n",
    "\n",
    "        Args:\n",
    "            x: LongTensor of word indices, batch_size x ngram_length.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of unnormalised scores, batch_size x number_of_words.\n",
    "        \"\"\"\n",
    "        # x.T has one row of word indices per history position; pair each row with its lookup\n",
    "        embs = torch.stack([lookup(position_ids) for position_ids, lookup in zip(x.T, self.embeddings)]) # ngram_length x batch_size x number_of_words\n",
    "        embs = torch.sum(embs, dim=0) # batch_size x number_of_words\n",
    "        scores = embs + self.bias\n",
    "\n",
    "        return scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Settings and Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the model and place it on the selected device before its parameters go to the optimizer\n",
    "model = LogLinear(number_of_words, N)\n",
    "if device == 'cuda':\n",
    "    model.to(device)\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# function to calculate the sentence loss\n",
    "def calc_sent_loss(sent):\n",
    "    \"\"\"Loss of the model on one sentence.\n",
    "\n",
    "    Args:\n",
    "        sent: word indices of the sentence, either a 1-D tensor (as yielded by\n",
    "            the data loader after indexing the batch) or a plain sequence of ints.\n",
    "\n",
    "    Returns:\n",
    "        The value of `criterion` over all n-gram predictions of the sentence,\n",
    "        including the prediction of the closing <s>.\n",
    "    \"\"\"\n",
    "    S = word_to_index[\"<s>\"]\n",
    "\n",
    "    # work on plain ints: `tensor + tensor` would add element-wise instead of appending <s>\n",
    "    words = sent.tolist() if torch.is_tensor(sent) else list(sent)\n",
    "\n",
    "    # initial history is equal to end of sentence symbols\n",
    "    hist = [S] * N\n",
    "\n",
    "    # collect all target and histories\n",
    "    all_targets = []\n",
    "    all_histories = []\n",
    "\n",
    "    # step through the sentence, including the end of sentence token\n",
    "    for next_word in words + [S]:\n",
    "        all_histories.append(list(hist))\n",
    "        all_targets.append(next_word)\n",
    "        hist = hist[1:] + [next_word]\n",
    "\n",
    "    # score every history of the sentence in a single forward call\n",
    "    logits = model(torch.LongTensor(all_histories).to(device))\n",
    "    loss = criterion(logits, torch.LongTensor(all_targets).to(device))\n",
    "\n",
    "    return loss\n",
    "\n",
    "MAX_LEN = 100\n",
    "# Function to generate a sentence\n",
    "def generate_sent():\n",
    "    \"\"\"Sample a sentence from the model one word at a time.\n",
    "\n",
    "    Returns:\n",
    "        List of sampled word indices. Sampling stops when <s> is drawn or\n",
    "        once MAX_LEN words have been collected.\n",
    "    \"\"\"\n",
    "    S = word_to_index[\"<s>\"]\n",
    "    hist = [S] * N\n",
    "    sent = []\n",
    "    # sampling never back-propagates, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        while True:\n",
    "            logits = model(torch.LongTensor([hist]).to(device))\n",
    "            # normalise over the vocabulary axis; an explicit dim avoids the deprecated implicit choice\n",
    "            p = torch.nn.functional.softmax(logits, dim=-1) # 1 x number_of_words\n",
    "            next_word = p.multinomial(num_samples=1).item()\n",
    "            if next_word == S or len(sent) == MAX_LEN:\n",
    "                break\n",
    "            sent.append(next_word)\n",
    "            hist = hist[1:] + [next_word]\n",
    "    return sent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--finished 5000 sentences\n",
      "--finished 10000 sentences\n",
      "--finished 15000 sentences\n",
      "--finished 20000 sentences\n",
      "--finished 25000 sentences\n",
      "--finished 30000 sentences\n",
      "--finished 35000 sentences\n",
      "--finished 40000 sentences\n",
      "iter 0: train loss/word=9.0947, ppl=8907.6500\n",
      "iter 0: dev loss/word=9.7668, ppl=17444.9221, time=1.76s\n",
      "in this case of the trade deficit of the globe weeks columnist months <unk> from a <unk> character succeed reflects as an effort will teaching mr. chestman was essentially flat to deal with the board is this time the <unk> an international <unk> machines are n't being any at this time you were n't disclosed this week it to take over a company said it will introduce a new york <unk> that since friday 's sharp swings in the field sales were down on N at a <unk> company said it will invest in quarterly profit by the new securities\n",
      "on monday at N yen $ N million navy contract for advanced there were <unk> when he 's no decision has been done by the bush administration has of new hampshire preferred holders total package that includes <unk> <unk> is that the full of only N to rise N N months of sept. N N share of $ N down N N N to N N to N this year and sales increased nearly N million shares outstanding as of that japan is starting in france spain italy and turkey late 1960s commissioner worthy of a food <unk> rose to\n",
      "speaking to build a giant <unk> corp. new york stock exchange during the first nine months <unk> charges for example banks <unk> station and gas production at the hands of our crowd efforts have been trying to plot against him the chief received the payment problem of that big institutions were never going to be loyal to try to <unk> units in the federal reserve <unk> onto the field <unk> with any securities by the irs recently said it will introduce a new york <unk> that replaced become known as <unk> <unk> resources inc. <unk> between what 's own decision\n",
      "these funds will be a <unk> it a better business he <unk> the market after the N <unk> after an a computer company for the defense plan and will come from a <unk> gene was missing acting expired award clients ' portfolios are the close of N million navy contract for <unk> an analyst with <unk> by saturday morning hat in big trading houses analysts expected to seek to clean up all <unk> says he is <unk> the best thing you do n't even the clutter of gold for current delivery of $ N million of $ N a vehicle\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/nlp/lib/python3.7/site-packages/ipykernel_launcher.py:38: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "advertisers and advertising rates for the s&p N issue of the issues <unk> pace with rival very small amounts to veto the constitution <unk> sen coordinator of the big three <unk> the las vegas 's increased <unk> activity is only one or for one thing is important as of as many as N million navy contract for the government is <unk> by mr. <unk> has <unk> business conditions and the earnings or N on the firm of that this is that mr. gorbachev 's economic activity and only half of the proposal to reduce interest rates in the <unk> he\n",
      "--finished 5000 sentences\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1861/185239032.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0msent_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msent\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# CHANGE to all train_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m         \u001b[0mmy_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcalc_sent_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msent\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0mtrain_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mmy_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1861/2590246674.py\u001b[0m in \u001b[0;36mcalc_sent_loss\u001b[0;34m(sent)\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0mhist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnext_word\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m     \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_histories\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# start training\n",
    "for ITER in range(10): # CHANGE to 100\n",
    "    # training\n",
    "    model.train()\n",
    "    train_words, train_loss = 0, 0.0\n",
    "    for sent_id, sent in enumerate(train_loader):\n",
    "        # the loader yields batches of one sentence; sent[0] is that sentence\n",
    "        my_loss = calc_sent_loss(sent[0])\n",
    "\n",
    "        # NOTE: len(sent) is the batch size (1), and the default CrossEntropyLoss averages over a\n",
    "        # sentence's predictions, so the ratio printed below is the mean of per-sentence mean losses\n",
    "        train_loss += my_loss.item()\n",
    "        train_words += len(sent)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        my_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if (sent_id+1) % 5000 == 0:\n",
    "            print(\"--finished %r sentences\" % (sent_id+1))\n",
    "    print(\"iter %r: train loss/word=%.4f, ppl=%.4f\" % (ITER, train_loss/train_words, math.exp(train_loss/train_words)))\n",
    "\n",
    "    # evaluation\n",
    "    model.eval()\n",
    "    dev_words, dev_loss = 0, 0.0\n",
    "    start = time.time()\n",
    "    # nothing is updated during evaluation, so do not record gradients\n",
    "    with torch.no_grad():\n",
    "        for sent_id, sent in enumerate(val_loader):\n",
    "            my_loss = calc_sent_loss(sent[0])\n",
    "            dev_loss += my_loss.item()\n",
    "            dev_words += len(sent)\n",
    "    print(\"iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs\" % (ITER, dev_loss/dev_words, math.exp(dev_loss/dev_words), time.time()-start))\n",
    "\n",
    "    # Generate a few sentences\n",
    "    with torch.no_grad():\n",
    "        for _ in range(5):\n",
    "            sent = generate_sent()\n",
    "            print(\" \".join([index_to_word[x] for x in sent]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "154abf72fb8cc0db1aa0e7366557ff891bff86d6d75b7e5f2e68a066d591bfd7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
